#include "RemoveTestNoUseOps.hpp"
/**
 * Removes the ops selected by shouldDeleteJudge() from `net`, rewires their
 * consumers, then sweeps ops whose outputs are no longer referenced.
 *
 * Phase 1: collect "protected" tensor names: the net's declared outputs plus
 *          every tensor produced by an Input op.
 * Phase 2: for each op that shouldDeleteJudge() selects:
 *   - keep it if it both reads and writes a protected tensor (removing it would
 *     require one tensor to carry two protected names);
 *   - drop it outright if it has no inputs or no outputs;
 *   - otherwise, if shouldDeleteOutput() is true, strip its outputs from every
 *     op's input list; else redirect every reader of its outputs to its first
 *     input. In the redirect case a protected output name is transferred to that
 *     first input. If more than one output is protected, the op is kept untouched,
 *     because a single tensor cannot take over several names.
 *   Inputs of removed ops for which shouldRemoveUnusefulInputs() is true are
 *   remembered so phase 3 may sweep their producers.
 * Phase 3: reference-count tensors and repeatedly erase ops whose outputs are
 *          all unreferenced, releasing the references held by their inputs.
 *
 * @param net  network rewritten in place.
 * @return always true.
 */
bool RemoveTestNoUseOps::onExecute(std::unique_ptr<MNN::NetT>& net) const {
    const MNN::NetT* const netPtr = net.get();

    // ---- Phase 1: tensor names that must survive this pass. ----
    std::set<std::string> netOutputNames(net->outputName.begin(), net->outputName.end());
    for (const auto& op : net->oplists) {
        if (op->type == OpType_Input) {
            for (auto o : op->outputIndexes) {
                netOutputNames.insert(net->tensorName[o]);
            }
        }
    }
    // Reads the *current* name, so renames done below are taken into account.
    auto isProtected = [&](int tensorIndex) {
        return netOutputNames.find(net->tensorName[tensorIndex]) != netOutputNames.end();
    };

    // ---- Phase 2: remove selected ops and rewire their consumers. ----
    std::unordered_set<int> removedInputs;
    for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
        auto& op = *iter;
        if (!shouldDeleteJudge(op.get(), netPtr)) {
            ++iter;
            continue;
        }
        const bool writesProtected =
            std::any_of(op->outputIndexes.begin(), op->outputIndexes.end(), isProtected);
        const bool readsProtected =
            std::any_of(op->inputIndexes.begin(), op->inputIndexes.end(), isProtected);
        if (readsProtected && writesProtected) {
            // Both sides are names that must survive; they cannot be merged.
            ++iter;
            continue;
        }
        const bool deleteOutput = shouldDeleteOutput(op.get());
        if (op->outputIndexes.empty() || op->inputIndexes.empty()) {
            iter = net->oplists.erase(iter);
            continue;
        }
        auto originInput   = op->inputIndexes[0];
        auto originOutputs = op->outputIndexes; // copy: op is erased below

        if (!deleteOutput && writesProtected) {
            // Readers are redirected to originInput, so it must inherit the
            // protected output's name. Validate before mutating anything: with
            // more than one protected output there is no single name to inherit.
            int protectedOutput = -1;
            int protectedCount  = 0;
            for (auto o : originOutputs) {
                if (isProtected(o)) {
                    protectedOutput = o;
                    ++protectedCount;
                }
            }
            if (protectedCount != 1) {
                ++iter;
                continue;
            }
            net->tensorName[originInput] = net->tensorName[protectedOutput];
        }

        auto producedByOp = [&](int tensorIndex) {
            return std::find(originOutputs.begin(), originOutputs.end(), tensorIndex) != originOutputs.end();
        };
        for (auto& subOp : net->oplists) {
            auto& inputs = subOp->inputIndexes;
            if (deleteOutput) {
                // Drop references to the removed outputs, keeping the order of the rest.
                inputs.erase(std::remove_if(inputs.begin(), inputs.end(), producedByOp), inputs.end());
            } else {
                for (auto& index : inputs) {
                    if (producedByOp(index)) {
                        index = originInput;
                    }
                }
            }
        }
        if (shouldRemoveUnusefulInputs(op.get())) {
            removedInputs.insert(op->inputIndexes.begin(), op->inputIndexes.end());
        }
        iter = net->oplists.erase(iter);
    }

    // ---- Phase 3: sweep ops whose outputs are no longer referenced. ----
    // An op is removed only when the reference counts of all its outputs are zero.
    std::unordered_map<int, int /*reference count*/> uselessIndex;
    for (const auto& op : net->oplists) {
        for (int input : op->inputIndexes) {
            ++uselessIndex[input]; // value-initialized to 0 on first use
        }
    }
    // Outputs with no consumer (e.g. net outputs) count as referenced once so
    // they are kept, unless they fed an op removed above with
    // shouldRemoveUnusefulInputs(), in which case they start unreferenced.
    // emplace() leaves tensors that already have consumers untouched.
    for (const auto& op : net->oplists) {
        for (int output : op->outputIndexes) {
            uselessIndex.emplace(output, removedInputs.count(output) ? 0 : 1);
        }
    }

    // Removing an op releases its inputs, which can make upstream producers dead
    // too, so iterate until a pass removes no op that had inputs.
    bool needIteration = false;
    do {
        needIteration = false;
        for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
            auto& op = *iter;
            const bool useless = std::none_of(op->outputIndexes.begin(), op->outputIndexes.end(),
                                              [&](int index) { return uselessIndex.at(index) > 0; });
            if (!useless) {
                ++iter;
                continue;
            }
            if (!op->inputIndexes.empty()) {
                for (auto index : op->inputIndexes) {
                    auto it = uselessIndex.find(index);
                    MNN_ASSERT(it != uselessIndex.end());
                    if (it != uselessIndex.end()) {
                        --it->second;
                    }
                }
                needIteration = true;
            }
            iter = net->oplists.erase(iter);
        }
    } while (needIteration);

    return true;
}
